from __future__ import annotations

import copy
import logging
import os
import re
import warnings

import torch
from transformers.cache_utils import DynamicCache

from ..base import BaseModel
from .prompt import ThymePromptMixin
from .sandbox import execute_code_in_sandbox
from .utils import (
    REASONING_SYS_PROMPT,
    SIMPLE_SYS_PROMPT,
    SPECIAL_STRING_LIST,
    generate_prompt_final_qa,
    generate_prompt_simple_qa,
)


def ensure_image_url(image: str) -> str:
    """Normalize an image reference into a URL string for the vision pipeline.

    Strings that already look like URLs (``http://``, ``https://``, ``file://``)
    or inline data URIs are returned unchanged. Both the legacy ``data:image;...``
    form and the standard RFC 2397 form ``data:image/png;base64,...`` are
    accepted. A string naming an existing local path is prefixed with
    ``file://``.

    Raises:
        ValueError: if ``image`` is neither a recognised URL nor an existing path.
    """
    # "data:image/" covers standard data URIs; "data:image;" is kept for
    # backward compatibility with callers that already build that form.
    prefixes = ("http://", "https://", "file://", "data:image;", "data:image/")
    if image.startswith(prefixes):
        return image
    if os.path.exists(image):
        return "file://" + image
    raise ValueError(f"Invalid image: {image}")


def ensure_video_url(video: str) -> str:
    """Return ``video`` as a URL string for the vision pipeline.

    Values that already start with ``http://``, ``https://``, ``file://`` or
    ``data:video;`` pass through untouched; an existing local path gets a
    ``file://`` prefix.

    Raises:
        ValueError: if ``video`` is neither a recognised URL nor an existing path.
    """
    if video.startswith(("http://", "https://", "file://", "data:video;")):
        return video
    if not os.path.exists(video):
        raise ValueError(f"Invalid video: {video}")
    return f"file://{video}"


class Thyme(ThymePromptMixin, BaseModel):
    """Qwen2.5-VL based model that interleaves reasoning with sandboxed code.

    Generation pauses at the stop strings in ``SPECIAL_STRING_LIST``. When a
    generated segment contains a ``<code>...</code>`` block, the code is run
    with ``execute_code_in_sandbox``. Its output (image paths, or a text
    entry) is appended to the assistant turn inside ``<sandbox_output>`` tags
    before generation resumes. A segment containing ``</answer>`` ends the
    search. If no such segment is produced after ``max_retry`` generations, a
    single-shot "simple" prompt is used as a fallback.
    """

    INSTALL_REQ = False
    INTERLEAVE = True
    VIDEO_LLM = True

    # Matches a <code>...</code> block, tolerating an optional ``` fence and a
    # leading "python" tag; group 1 is the code itself. Compiled once here
    # instead of on every generation step.
    _CODE_BLOCK_REGEX = re.compile(
        r"<code>\s*(?:```\s*)?(?:python\s*)?([\s\S]*?)\s*(?:```\s*)?</code>",
        re.IGNORECASE,
    )
    # Chat-template suffix closing a message; stripped from the rendered
    # prompt when the model should continue the last assistant message.
    _IM_END = "<|im_end|>\n"

    def __init__(
        self,
        model_path: str,
        min_pixels: int | None = None,
        max_pixels: int | None = None,
        max_new_tokens=2048,
        top_p=0.001,
        top_k=1,
        temperature=0.01,
        repetition_penalty=1.0,
        # pandayin: rounds of intermediate steps before reaching final answer.
        max_iterations=5,
        # pandayin: max retry before reaching a valid answer.
        max_retry=5,
        use_custom_prompt: bool = True,
        system_prompt: str | None = "You are a helpful assistant.",
        # If True, only return the text wrapped in <answer> & </answer>.
        post_process: bool = True,
        verbose: bool = True,
        # Clean up intermediate images generated by sandbox code.
        auto_cleanup: bool = True,
        **kwargs,
    ):
        r"""Load the processor and the model, and store generation settings.

        Args:
            model_path: Model id or directory passed to ``from_pretrained`` for
                both the processor and the model.
            min_pixels, max_pixels: Optional pixel limits set on each image and
                video content item (left unset when ``None``).
            max_new_tokens, top_p, top_k, temperature, repetition_penalty:
                Forwarded to ``model.generate``.
            max_iterations: Steps (generate, then maybe run code) per generation.
            max_retry: Generations attempted before the simple-prompt fallback.
            use_custom_prompt: Forwarded to ``super().__init__``.
            system_prompt: Only checked for ``None`` in the simple-prompt
                fallback, which adds ``SIMPLE_SYS_PROMPT`` when it is set.
            post_process: Return only the ``<answer>`` contents (unwrapping a
                ``\boxed{}`` if present) instead of the raw assistant text.
            verbose: Print prompts, iterations and the extracted answer.
            auto_cleanup: Give the sandbox a temporary output directory that is
                removed after each ``generate_inner`` call.
            **kwargs: Accepted for configuration compatibility; unused.
        """
        from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
        super().__init__(use_custom_prompt=use_custom_prompt)
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.top_p = top_p
        self.top_k = top_k
        self.temperature = temperature
        self.system_prompt = system_prompt
        self.max_iterations = max_iterations
        self.max_retry = max_retry
        self.verbose = verbose
        self.post_process = post_process
        self.auto_cleanup = auto_cleanup
        self.fps = 2.0
        self.nframe = 64
        self.FRAME_FACTOR = 2
        assert model_path is not None
        self.model_path = model_path
        self.processor = AutoProcessor.from_pretrained(model_path)

        # Base generation settings. generate_inner_transformers works on a
        # per-call copy, so these are never mutated.
        self.generate_kwargs = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_strings=SPECIAL_STRING_LIST,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            tokenizer=self.processor.tokenizer,
        )

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map="auto",
            attn_implementation="sdpa",
        )
        self.model.eval()

        torch.cuda.empty_cache()

    # ------------------------------------------------------------------ #
    # Input preparation
    # ------------------------------------------------------------------ #
    def _extract_image_path(self, contents: list[dict[str, str]]) -> str:
        """Return the ``value`` of the first ``"image"`` item, or "" if none."""
        for item in contents:
            if item["type"] == "image":
                return item["value"]
        return ""

    def _video_nframes(self, video_path: str) -> int:
        """Number of frames to sample from ``video_path`` (used when fps is unset).

        Returns ``self.nframe``, unless the video has fewer frames. In that case
        its frame count, rounded down to a multiple of ``FRAME_FACTOR``, is used.
        """
        import cv2

        video = cv2.VideoCapture(video_path)
        try:
            frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            video.release()
        if frame_count >= self.nframe:
            return self.nframe
        new_frame_count = frame_count // self.FRAME_FACTOR * self.FRAME_FACTOR
        print(f"use {new_frame_count} for {video_path}")
        return new_frame_count

    def _build_visual_item(self, s: dict[str, str], dataset: str | None = None):
        """Build the content dict for one ``"image"`` or ``"video"`` input item.

        Pixel limits are only set when configured, so an explicit ``None`` is
        never passed on to the vision preprocessing.
        """
        if s["type"] == "image":
            item = {"type": "image", "image": ensure_image_url(s["value"])}
            if dataset == "OCRBench":
                item["min_pixels"] = 10 * 10 * 28 * 28
                warnings.warn(
                    f"OCRBench dataset uses custom min_pixels={item['min_pixels']}"
                )
            elif self.min_pixels is not None:
                item["min_pixels"] = self.min_pixels
            if self.max_pixels is not None:
                item["max_pixels"] = self.max_pixels
            return item

        # Video item. Previously min/max_pixels were always set, even to None.
        item = {"type": "video", "video": ensure_video_url(s["value"])}
        if self.min_pixels is not None:
            item["min_pixels"] = self.min_pixels
        if self.max_pixels is not None:
            item["max_pixels"] = self.max_pixels
        if self.fps is not None:
            item["fps"] = self.fps
        elif self.nframe is not None:
            item["nframes"] = self._video_nframes(s["value"])
        return item

    def _build_message_content(self, inputs, dataset, make_text):
        """Convert input items to content dicts; ``make_text`` wraps text values.

        Raises:
            ValueError: on an item whose type is not image/video/text.
        """
        content = []
        for s in inputs:
            if s["type"] in ("image", "video"):
                content.append(self._build_visual_item(s, dataset=dataset))
            elif s["type"] == "text":
                content.append({"type": "text", "text": make_text(s["value"])})
            else:
                raise ValueError(f"Invalid message type: {s['type']}, {s}")
        return content

    def _prepare_content(
        self, inputs: list[dict[str, str]], dataset: str | None = None
    ) -> list[dict[str, str]]:
        """Build user content for the reasoning (code-enabled) prompt.

        ``inputs`` items have keys ``"type"`` and ``"value"``. Text values are
        wrapped with ``generate_prompt_final_qa``, which also receives the
        path of the first input image.
        """
        user_image_path = self._extract_image_path(inputs)
        return self._build_message_content(
            inputs,
            dataset,
            lambda text: generate_prompt_final_qa(text, user_image_path),
        )

    def _prepare_content_simple(
        self, inputs: list[dict[str, str]], dataset: str | None = None
    ) -> list[dict[str, str]]:
        """Build user content for the simple fallback prompt.

        Same as ``_prepare_content``, but text values are wrapped with
        ``generate_prompt_simple_qa``.
        """
        return self._build_message_content(
            inputs, dataset, generate_prompt_simple_qa
        )

    # ------------------------------------------------------------------ #
    # Small utilities
    # ------------------------------------------------------------------ #
    def _extract_box_answer(self, response):
        r"""Return the contents of the last ``\boxed{...}`` in ``response``.

        Nested braces are balanced. If the box is never closed, everything after
        the last ``\boxed{`` is returned; if nothing follows it, ``response``
        is returned unchanged.
        """
        tail = response.split("\\boxed{")[-1]
        depth = 1
        for index, char in enumerate(tail):
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return tail[:index]
        return tail if tail else response

    def _remove_unpickable_values(self, dictionary):
        """Delete, in place, the entries of ``dictionary`` that cannot be pickled.

        Nested dicts are cleaned recursively rather than dropped. The same
        (mutated) dict is returned, ready to be passed to ``copy.deepcopy``.
        """
        import pickle

        def is_pickable(obj):
            try:
                pickle.dumps(obj)
            except Exception:
                # Any failure (PicklingError, TypeError, RecursionError, errors
                # from a custom __reduce__, ...) means it cannot be snapshotted.
                return False
            return True

        keys_to_remove = []
        for key, value in dictionary.items():
            if isinstance(value, dict):
                self._remove_unpickable_values(value)
            elif not is_pickable(value):
                keys_to_remove.append(key)
        for key in keys_to_remove:
            del dictionary[key]
        return dictionary

    # ------------------------------------------------------------------ #
    # Model I/O helpers
    # ------------------------------------------------------------------ #
    def _encode_conversation(self, conversation, continue_assistant=False):
        """Render ``conversation`` with the chat template and tokenize it on CUDA.

        With ``continue_assistant`` set, no generation prompt is added. The
        trailing ``<|im_end|>\\n`` of the last (assistant) message is stripped,
        so the model continues that message instead of starting a new turn.
        """
        from qwen_vl_utils import process_vision_info

        text = self.processor.apply_chat_template(
            [conversation],
            tokenize=False,
            add_generation_prompt=not continue_assistant,
        )
        if continue_assistant and text[0].endswith(self._IM_END):
            text[0] = text[0][: -len(self._IM_END)]
        images, videos = process_vision_info([conversation])
        inputs = self.processor(
            text=text,
            images=images,
            videos=videos,
            padding=True,
            return_tensors="pt",
        )
        return inputs.to("cuda")

    def _generate_segment(self, inputs, gen_kwargs, past_key_values=None):
        """Run ``model.generate`` and decode only the newly generated tokens.

        ``past_key_values`` is forwarded only when given (it is updated by
        ``generate``).

        Returns:
            ``(new_token_ids, text)``: the per-sample lists of new token ids and
            the decoded text of the first sample.
        """
        extra = {} if past_key_values is None else {"past_key_values": past_key_values}
        output_ids = self.model.generate(**inputs, **gen_kwargs, **extra)
        new_ids = [
            out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)
        ]
        decoded = self.processor.tokenizer.batch_decode(
            new_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return new_ids, decoded[0]

    @staticmethod
    def _sandbox_output_content(generated_text_segment, processed_img_paths):
        """Content items for a code step followed by its sandbox output.

        If the first returned entry exists on disk, every existing entry is
        attached as an image. Otherwise the first entry is attached verbatim
        as text.
        """
        content = [
            {"type": "text", "text": generated_text_segment},
            {"type": "text", "text": "<sandbox_output>"},
        ]
        first_path = processed_img_paths[0]
        if os.path.exists(first_path):
            content.extend(
                {"type": "image", "image": img_path}
                for img_path in processed_img_paths
                if os.path.exists(img_path)
            )
        else:
            content.append({"type": "text", "text": first_path})
        content.append({"type": "text", "text": "</sandbox_output>"})
        return content

    @staticmethod
    def _last_assistant_text(conversation):
        """Concatenated text items of the last assistant message ("" if none)."""
        for msg in reversed(conversation):
            if msg["role"] == "assistant":
                return "".join(
                    item["text"] for item in msg["content"] if item["type"] == "text"
                )
        return ""

    def _extract_final_answer(self, final_assistant_response):
        r"""Return the ``<answer>`` contents of the response, unwrapping ``\boxed{}``."""
        print(
            f"\033[31m--- Final response ---\n{final_assistant_response}\n-------------------------\033[0m"
        )
        answer_match = re.search(
            r"<answer>(.*?)</answer>", final_assistant_response, re.DOTALL
        )
        if answer_match:
            final_answer = answer_match.group(1).strip()
        else:
            final_answer = "No answer tag found in the final output."

        # Sometimes the answer is still wrapped in \boxed{}, keeping the
        # behaviour of Qwen2.5-VL. We extract the answer within this.
        if re.search(r"\\boxed\{(.*?)\}", final_answer):
            final_answer = self._extract_box_answer(final_answer)

        if self.verbose:
            print(f"\033[32m{final_answer}\033[0m")
        return final_answer

    def _generate_simple_fallback(self, message, dataset=None):
        """Answer ``message`` in one shot with the simple prompt.

        Uses the base generation settings without stop strings. Returns the
        conversation with the assistant reply appended; the reply is wrapped in
        ``<answer>`` tags when the model did not emit them itself.
        """
        fallback_kwargs = {
            key: value
            for key, value in self.generate_kwargs.items()
            if key != "stop_strings"
        }
        conversation = []
        if self.system_prompt is not None:
            conversation.append({"role": "system", "content": SIMPLE_SYS_PROMPT})
        conversation.append(
            {
                "role": "user",
                "content": self._prepare_content_simple(message, dataset=dataset),
            }
        )
        inputs = self._encode_conversation(conversation)
        _, answer_text = self._generate_segment(inputs, fallback_kwargs)

        # Align with the post-processing below: wrap in an <answer> bracket.
        if not re.search(r"<answer>(.*?)</answer>", answer_text, re.DOTALL):
            answer_text = "<answer>" + answer_text + "</answer>"
        conversation.append(
            {"role": "assistant", "content": [{"type": "text", "text": answer_text}]}
        )
        return conversation

    # ------------------------------------------------------------------ #
    # Main entry points
    # ------------------------------------------------------------------ #
    def generate_inner_transformers(self, message, dataset=None, temp_output_dir=None):
        """Answer ``message`` with the iterative reason / code / sandbox loop.

        **Loop.** The outer loop makes up to ``max_retry`` generations. Each
        generation starts from a fresh conversation, KV cache and sandbox
        context, and runs up to ``max_iterations`` steps. A step generates until
        a stop string. If the segment holds a ``<code>`` block, the code is
        executed and its output is appended to the assistant turn. If the
        execution returns no output, the step is rolled back (KV cache and
        sandbox context restored).

        **Stopping.** The first segment containing ``</answer>`` ends the search.
        If none is found, ``_generate_simple_fallback`` is used.

        **Temperature.** After a failed attempt (a segment with neither code nor
        an answer, or a whole generation without an answer), the temperature is
        raised to 1.0 for the rest of this call. This is done on a per-call copy
        of ``self.generate_kwargs``, so it never leaks into later calls, even if
        generation raises.

        Args:
            message: List of ``{"type": ..., "value": ...}`` input items.
            dataset: Dataset name; "OCRBench" uses a larger image min_pixels.
            temp_output_dir: Output directory handed to the sandbox.

        Returns:
            The extracted answer if ``post_process`` is set, otherwise the full
            text of the last assistant message.
        """
        try:
            from qwen_vl_utils import process_vision_info  # noqa: F401
        except Exception:
            logging.critical(
                "qwen_vl_utils not found, please install it via 'pip install qwen-vl-utils'"
            )  # noqa: E501
            raise

        user_image_path = self._extract_image_path(message)
        messages = [
            {"role": "system", "content": REASONING_SYS_PROMPT},
            {"role": "user", "content": self._prepare_content(message, dataset=dataset)},
        ]
        if self.verbose:
            print(f"\033[31m{messages}\033[0m")

        gen_kwargs = dict(self.generate_kwargs)
        conversation_history = messages

        #   -------   outer loop. retry multiple times if fail to reach a valid answer.
        retry_generations = self.max_retry
        has_valid_answer = False
        while retry_generations > 0 and not has_valid_answer:
            # pandayin: pause at special tokens (</code> & </answer>) and maybe
            # perform code execution.
            conversation_history = copy.deepcopy(messages)
            # One KV cache per generation to speed up the multi-step inference.
            kv_cache = DynamicCache()
            # Context (local & global vars) shared across code executions.
            previous_execution_context = {}
            if self.verbose:
                print(
                    f"\033[32m\n--- Generation {self.max_retry - retry_generations + 1} ---\033[0m"
                )

            #   -------   inner loop. generate multiple steps until reaching an answer.
            retry_iterations = self.max_iterations
            while retry_iterations > 0:
                retry_iterations -= 1
                if self.verbose:
                    print(
                        f"\033[32m\n--- Iteration {self.max_iterations - retry_iterations} ---\033[0m"
                    )

                # Open a new assistant turn only when the last message is the
                # user's; otherwise continue the partial assistant message.
                # Keying this on the iteration number broke the turn structure
                # after a rolled-back first step.
                inputs = self._encode_conversation(
                    conversation_history,
                    continue_assistant=conversation_history[-1]["role"] == "assistant",
                )

                # Snapshots used to roll back if this step's code yields nothing.
                last_kv_cache = copy.deepcopy(kv_cache)
                last_execution_context = copy.deepcopy(
                    self._remove_unpickable_values(previous_execution_context)
                )
                generated_ids, generated_text_segment = self._generate_segment(
                    inputs, gen_kwargs, past_key_values=kv_cache
                )

                generated_content = []
                # Case 1: directly give answer.
                if "</answer>" in generated_text_segment:
                    generated_content.append(
                        {"type": "text", "text": generated_text_segment}
                    )

                # Case 2: reached a code block; execute it and collect output.
                code_match = self._CODE_BLOCK_REGEX.search(generated_text_segment)
                if code_match:
                    code_to_execute = code_match.group(1).strip()
                    if self.verbose:
                        print(
                            f"\033[31m--- Found Code Block ---\n"
                            f"{generated_text_segment}\n"
                            f",-------------------------\033[0m"
                        )
                    (
                        processed_img_paths,
                        _captured_stdout,
                        error_msg,
                        current_execution_context,
                    ) = execute_code_in_sandbox(
                        code_to_execute,
                        user_image_path,
                        temp_output_dir=temp_output_dir,
                        previous_execution_context=previous_execution_context,
                    )
                    previous_execution_context = current_execution_context
                    if not processed_img_paths:
                        # Unsuccessful step: roll back cache and context.
                        kv_cache = last_kv_cache
                        previous_execution_context = last_execution_context
                        print(f"{error_msg}")
                        continue
                    generated_content += self._sandbox_output_content(
                        generated_text_segment, processed_img_paths
                    )
                elif "</answer>" not in generated_text_segment:
                    # wo code. wo </answer>: assume degenerate/repetitive output.
                    print("wo code. wo </answer>")
                    print(generated_text_segment)
                    gen_kwargs["temperature"] = 1.0
                    break

                # Append the new segment to the (possibly new) assistant turn.
                if conversation_history[-1]["role"] == "user":
                    conversation_history.append(
                        {"role": "assistant", "content": generated_content}
                    )
                elif conversation_history[-1]["role"] == "assistant":
                    conversation_history[-1]["content"] += generated_content

                if "</answer>" in generated_text_segment:
                    has_valid_answer = True
                    print("\033[32m--- Final answer tag found. ---\033[0m")
                    break

                # EOS without code/answer: assume the model is done.
                new_tokens = generated_ids[0]
                if (
                    len(new_tokens) > 0
                    and new_tokens[-1] == self.processor.tokenizer.eos_token_id
                ):
                    if self.verbose:
                        print(
                            "\033[32m--- Model generated EOS and no further actions (code/answer)."
                            "Assuming completion. ---\033[0m"
                        )
                    break

            # End of one generation: either done, or start another one.
            if has_valid_answer:
                if self.verbose:
                    print(
                        f"\033[32m\n--- End of processing (max iterations: {self.max_iterations},"
                        f"actual: {self.max_iterations - retry_iterations}) ---\033[0m"
                    )
                break
            if self.verbose:
                print(
                    f"\033[32m\n --- Fail to find a valid answer. (max retrys: {self.max_retry},"
                    f"actual: {self.max_retry - retry_generations + 1})---\033[0m"
                )
            retry_generations -= 1
            # pandayin: raise the temperature so more exploration can be done.
            print("Fail to find a valid answer and adapt the temperature")
            gen_kwargs["temperature"] = 1.0

        # pandayin: still no answer after max_retry generations; simple prompt.
        if not has_valid_answer:
            print(
                f"\033[32m\n --- Fail to find a valid answer after {self.max_retry} retrys."
                f"Falling back to simple prompt.---\033[0m"
            )
            conversation_history = self._generate_simple_fallback(message, dataset)

        final_assistant_response = self._last_assistant_text(conversation_history)
        if not self.post_process:
            return final_assistant_response
        return self._extract_final_answer(final_assistant_response)

    def generate_inner(self, message, dataset=None):
        """Generate an answer for ``message``.

        With ``auto_cleanup``, the sandbox writes into a temporary directory
        that is deleted when the call returns. Otherwise no output directory is
        passed (``temp_output_dir=None``).
        """
        if self.auto_cleanup:
            import tempfile

            with tempfile.TemporaryDirectory() as temp_dir:
                return self.generate_inner_transformers(
                    message, dataset=dataset, temp_output_dir=temp_dir
                )
        else:
            return self.generate_inner_transformers(message, dataset=dataset)
